from mooquant import bar, strategy
from mooquant.barfeed.tusharefeed import Feed
from mooquant.broker.backtesting import TradePercentage
from mooquant.broker.fillstrategy import DefaultStrategy
from mooquant.technical import cross, ma


class thrSMA(strategy.BacktestingStrategy):
    """Long-only strategy driven by three simple moving averages.

    Entry: the short SMA crosses above the mid SMA (``buyCon1``) while both
    the short and mid SMA have been above the long SMA for the last
    ``up_cum`` bars (``buyCon2``).

    Exit: the short SMA crosses below the mid SMA (``sellCon1``).

    Each entry is sized at 20% of the broker's available cash.
    """

    def __init__(self, feed, instrument, short_l, mid_l, long_l, up_cum):
        """Configure the broker and build the three SMA indicators.

        :param feed: bar feed that provides ``instrument``.
        :param instrument: identifier of the traded instrument.
        :param short_l: period of the short SMA (converted with ``int``).
        :param mid_l: period of the mid SMA (converted with ``int``).
        :param long_l: period of the long SMA (converted with ``int``).
        :param up_cum: number of most recent bars over which the short and
            mid SMA must both be above the long SMA before buying.
        """
        strategy.BacktestingStrategy.__init__(self, feed)

        self.getBroker().setFillStrategy(DefaultStrategy(None))
        self.getBroker().setCommission(TradePercentage(0.001))

        self.__position = None
        self.__instrument = instrument
        self.__prices = feed[instrument].getPriceDataSeries()

        self.__malength1 = int(short_l)
        self.__malength2 = int(mid_l)
        self.__malength3 = int(long_l)

        self.__circ = int(up_cum)

        self.__ma1 = ma.SMA(self.__prices, self.__malength1)
        self.__ma2 = ma.SMA(self.__prices, self.__malength2)
        self.__ma3 = ma.SMA(self.__prices, self.__malength3)

    def getPrice(self):
        """Return the price data series the indicators are computed on."""
        return self.__prices

    def getSMA(self):
        """Return the ``(short, mid, long)`` SMA data series as a tuple."""
        return self.__ma1, self.__ma2, self.__ma3

    def onEnterCanceled(self, position):
        """Forget the pending position when its entry order is canceled."""
        self.__position = None

    def onEnterOk(self, position):
        """Entry-filled hook; nothing to do, the position is already tracked.

        NOTE(review): the framework is assumed to follow the PyAlgoTrade
        naming (``onEnterOk``); confirm against the mooquant strategy base.
        """
        pass

    def onEnterOK(self, position=None):
        """Backward-compatible alias for the historical ``onEnterOK`` name."""
        pass

    def onExitOk(self, position):
        """Clear the tracked position once its exit order is filled."""
        self.__position = None

    def onExitCanceled(self, position):
        """Resubmit the market exit if the previous exit order was canceled."""
        self.__position.exitMarket()

    def buyCon1(self):
        """Return True when the short SMA has just crossed above the mid SMA."""
        return cross.cross_above(self.__ma1, self.__ma2) > 0

    def buyCon2(self):
        """Return True when short and mid SMA exceeded the long SMA on each
        of the last ``up_cum`` bars.

        Returns False when fewer than ``up_cum`` values exist or when any
        value in the window is still ``None`` (indicator warm-up), instead
        of raising ``IndexError`` / ``TypeError``.
        """
        window = self.__circ
        short_ma, mid_ma, long_ma = self.__ma1, self.__ma2, self.__ma3

        # Not enough history to look back ``window`` bars.
        if min(len(short_ma), len(mid_ma), len(long_ma)) < window:
            return False

        for back in range(1, window + 1):
            short_v = short_ma[-back]
            mid_v = mid_ma[-back]
            long_v = long_ma[-back]

            # A missing value cannot satisfy the "above the long SMA" test.
            if short_v is None or mid_v is None or long_v is None:
                return False

            if not (short_v > long_v and mid_v > long_v):
                return False

        return True

    def sellCon1(self):
        """Return True when the short SMA has just crossed below the mid SMA."""
        return cross.cross_below(self.__ma1, self.__ma2) > 0

    def onBars(self, bars):
        """Evaluate the exit and entry rules for the newly received bars."""
        # Wait until the mid SMA has produced its first value.
        if self.__ma2[-1] is None:
            return

        if self.__position is not None:
            # Holding (or entering) a position: only the exit rule applies.
            if not self.__position.exitActive() and self.sellCon1():
                self.__position.exitMarket()
            return

        # No open position: check if we should enter a long position.
        if self.buyCon1() and self.buyCon2():
            current_bar = bars[self.__instrument]
            price = current_bar.getPrice()
            shares = int(self.getBroker().getCash() * 0.2 / price)

            # Too little cash for even one share: skip rather than submit a
            # zero-quantity order (presumably rejected by the broker).
            if shares <= 0:
                return

            self.__position = self.enterLong(self.__instrument, shares)

            print(current_bar.getDateTime(), price)


def runStratOnTushare(strat, paras, security_id, market, frequency):
    """Build a tushare ``Feed`` for one security and run a strategy on it.

    :param strat: strategy class (or factory) called as
        ``strat(feed, security_id, *paras)``.
    :param paras: positional arguments forwarded to ``strat`` after the
        feed and the security id.
    :param security_id: identifier subscribed on the feed and passed to
        ``strat`` as its instrument.
    :param market: accepted for interface compatibility; not used here.
    :param frequency: bar frequency forwarded to ``Feed``.
    """
    # The trailing ``1024, 5`` are passed to Feed positionally, as before.
    feed = Feed([security_id], frequency, 1024, 5)
    strat(feed, security_id, *paras).run()


if __name__ == "__main__":
    # Run the three-SMA strategy on security 600848 with minute bars.
    # ``paras`` maps to (short_l, mid_l, long_l, up_cum) of thrSMA.
    runStratOnTushare(
        strat=thrSMA,
        paras=[2, 20, 60, 10],
        security_id='600848',
        market='SH',
        frequency=bar.Frequency.MINUTE,
    )
